package it.uniroma2.sk;

import java.util.Arrays;
import java.util.HashMap;

public class SequenceKernel<T> {
	
	private double lambda = 1;
	private double lambdaSq = 1;
	
	// A bound of 0 means unbounded 
	private int bound = 0;
	
	public SequenceKernel() {
		lambdaSq = lambda*lambda;
	}
	
	public SequenceKernel(double lambda) {
		setLambda(lambda);
	}
	
	private HashMap<String, Double> map = new HashMap<String, Double>();
	
	public double value(T[] a, T[] b) {
		map.clear();
		double sum = 0;
		int bound = this.bound == 0 ? Math.min(a.length, b.length) : this.bound;
		for (int i=1; i<=bound; i++)
			sum += value(a, b, i);
		return sum;
	}

	public double value(T[] sx, T[] t, int n) {
		if (sx.length < n || t.length < n)
			return 0;
		String key = "0\t"+n+"\t"+Arrays.hashCode(sx)+"\t"+Arrays.hashCode(t);
		if (map.containsKey(key))
			return map.get(key);
		T[] s = Arrays.copyOfRange(sx, 0, sx.length-1);
		double sum = value(s, t, n);
		T c = sx[sx.length-1];
		for (int j=0; j<t.length; j++) {
			if (t[j].equals(c))
				sum += lambdaSq * K1(s, Arrays.copyOfRange(t, 0, j), n-1);
		}
		map.put(key, sum);
		return sum;
	}
	
	private double K1(T[] sx, T[] t, int n) {
		if (n == 0) 
			return 1;
		if (sx.length < n || t.length < n)
			return 0;
		String key = "1\t"+n+"\t"+Arrays.hashCode(sx)+"\t"+Arrays.hashCode(t);
		if (map.containsKey(key))
			return map.get(key);
		T[] s = Arrays.copyOfRange(sx, 0, sx.length-1);
		double sum = lambda * K1(s, t, n) + K2(sx, t, n);
		map.put(key, sum);
		return sum;
	}
	
	private double K2(T[] sx, T[] tu, int n) {
		if (tu.length < n || sx.length < n)
			return 0;
		String key = "2\t"+n+"\t"+Arrays.hashCode(sx)+"\t"+Arrays.hashCode(tu);
		if (map.containsKey(key))
			return map.get(key);
		T[] t = Arrays.copyOfRange(tu, 0, tu.length-1);
		double sum = lambda * K2(sx,t,n);
		if (sx[sx.length-1].equals(tu[tu.length-1])) {
			T[] s = Arrays.copyOfRange(sx, 0, sx.length-1);
			sum = sum + lambdaSq * K1(s, t, n-1);
		}
		map.put(key, sum);
		return sum;
	}

	public double getLambda() {
		return lambda;
	}

	public void setLambda(double lambda) {
		this.lambda = lambda;
		lambdaSq = lambda*lambda;
	}

	public int getBound() {
		return bound;
	}

	public void setBound(int bound) {
		this.bound = bound;
	}

}
